{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU index 2 to this process. This is set before TensorFlow is\n",
    "# imported in a later cell (the index is presumably specific to the original\n",
    "# training machine - adjust for your hardware).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://huseinhouse-storage.s3-ap-southeast-1.amazonaws.com/bert-bahasa/albert-tiny-2020-04-17.tar.gz\n",
    "# !tar -zxf albert-tiny-2020-04-17.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/lamb_optimizer.py:34: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from albert import modeling\n",
    "from albert import optimization\n",
    "from albert import tokenization\n",
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/tokenization.py:240: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:loading sentence piece model\n"
     ]
    }
   ],
   "source": [
    "# Build the ALBERT tokenizer from the SentencePiece vocab/model files found in the\n",
    "# albert-tiny-2020-04-17 directory; do_lower_case=False keeps the text cased.\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file='albert-tiny-2020-04-17/sp10m.cased.v10.vocab', do_lower_case=False,\n",
    "      spm_model_file='albert-tiny-2020-04-17/sp10m.cased.v10.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:116: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<albert.modeling.AlbertConfig at 0x7f9c4ac759e8>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the ALBERT configuration stored next to the pre-trained checkpoint; the\n",
    "# Model class below reads it as a module-level global.\n",
    "albert_config = modeling.AlbertConfig.from_json_file('albert-tiny-2020-04-17/config.json')\n",
    "albert_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rules import normalized_chars\n",
    "import random\n",
    "from unidecode import unidecode\n",
    "import re\n",
    "\n",
    "# Substrings that mark a token as droppable filler: `cleaning` lowercases each\n",
    "# token and, when it contains any entry of this set, keeps it only with\n",
    "# probability 0.5. Despite the name, the set also holds non-laughter entries\n",
    "# (e.g. 'alhamdulillah', 'salamramadhan').\n",
    "laughing = {\n",
    "    'huhu',\n",
    "    'haha',\n",
    "    'gagaga',\n",
    "    'hihi',\n",
    "    'wkawka',\n",
    "    'wkwk',\n",
    "    'kiki',\n",
    "    'keke',\n",
    "    'huehue',\n",
    "    'hshs',\n",
    "    'hoho',\n",
    "    'hewhew',\n",
    "    'uwu',\n",
    "    'sksk',\n",
    "    'ksks',\n",
    "    'gituu',\n",
    "    'gitu',\n",
    "    'mmeeooww',\n",
    "    'meow',\n",
    "    'alhamdulillah',\n",
    "    'muah',\n",
    "    'mmuahh',\n",
    "    'hehe',\n",
    "    'salamramadhan',\n",
    "    'happywomensday',\n",
    "    'jahagaha',\n",
    "    'ahakss',\n",
    "    'ahksk'\n",
    "}\n",
    "\n",
    "def make_cleaning(s, c_dict):\n",
    "    \"\"\"Return `s` with its characters remapped through the translation table `c_dict`.\"\"\"\n",
    "    return s.translate(c_dict)\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"\n",
    "    Normalise raw text before tokenization by any transformer model.\n",
    "\n",
    "    Steps, in order: transliterate with `unidecode`, remap characters through\n",
    "    `normalized_chars`, turn '(dot)' into '.', strip the attributes of the first\n",
    "    `<a ...>` tag when they contain 'href', blank out URLs, pad '.', ',' and '/'\n",
    "    with spaces, drop tokens starting with '@', lowercase every token and keep\n",
    "    tokens containing an entry of `laughing` only with probability 0.5.\n",
    "\n",
    "    Returns the surviving tokens joined by single spaces. The result is not\n",
    "    deterministic because of the random dropping (no seed is set here).\n",
    "    \"\"\"\n",
    "    string = unidecode(string)\n",
    "    \n",
    "    string = ' '.join(\n",
    "        [make_cleaning(w, normalized_chars) for w in string.split()]\n",
    "    )\n",
    "    # Raw string: a plain literal would contain the invalid escape sequence \\(.\n",
    "    string = re.sub(r'\\(dot\\)', '.', string)\n",
    "    # Strip the attributes of the first anchor tag when it carries an href.\n",
    "    # The captured text is removed literally (str.replace): feeding it back into\n",
    "    # re.sub as a pattern mis-matches or raises on characters such as '?' or '('.\n",
    "    # The same '<a ' pattern is used for the check and the removal so that the\n",
    "    # tag that is tested is the tag that is stripped.\n",
    "    anchors = re.findall(r'\\<a (.*?)\\>', string)\n",
    "    if anchors and 'href' in anchors[0]:\n",
    "        string = string.replace(' ' + anchors[0], '')\n",
    "    string = re.sub(\n",
    "        r'\\w+:\\/{2}[\\d\\w-]+(\\.[\\d\\w-]+)*(?:(?:\\/[^\\s/]*))*', ' ', string\n",
    "    )\n",
    "    \n",
    "    chars = '.,/'\n",
    "    for c in chars:\n",
    "        string = string.replace(c, f' {c} ')\n",
    "        \n",
    "    string = re.sub(r'[ ]+', ' ', string).strip().split()\n",
    "    string = [w for w in string if w[0] != '@']\n",
    "    x = []\n",
    "    for word in string:\n",
    "        word = word.lower()\n",
    "        if any([laugh in word for laugh in laughing]):\n",
    "            # Filler token: keep it only half of the time.\n",
    "            if random.random() >= 0.5:\n",
    "                x.append(word)\n",
    "        else:\n",
    "            x.append(word)\n",
    "    # NOTE(review): every word was lowercased above, so this title-casing looks\n",
    "    # unreachable for ASCII text; kept unchanged - confirm the intended casing.\n",
    "    string = [w.title() if w[0].isupper() else w for w in x]\n",
    "    return ' '.join(string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['anger', 'fear', 'happy', 'love', 'sadness', 'surprise'])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "# Class names in label-id order: a sample's integer label is the index of its\n",
    "# emotion in this list (see emotion_label.index(k) further down).\n",
    "emotion_label = ['anger', 'fear', 'happy', 'love', 'sadness', 'surprise']\n",
    "\n",
    "# JSON object keyed by emotion name; each value is a collection of texts for\n",
    "# that emotion (sizes are printed in the next cell).\n",
    "with open('../sentiment/emotion-twitter-lexicon.json') as fopen:\n",
    "    emotion = json.load(fopen)\n",
    "    \n",
    "emotion.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "anger 108813\n",
      "fear 20316\n",
      "happy 30962\n",
      "love 20783\n",
      "sadness 26468\n",
      "surprise 13107\n"
     ]
    }
   ],
   "source": [
    "# Report how many samples each emotion class holds before balancing.\n",
    "for name, samples in emotion.items():\n",
    "    print(name, len(samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "anger 30000\n",
      "fear 20316\n",
      "happy 30000\n",
      "love 20783\n",
      "sadness 26468\n",
      "surprise 13107\n"
     ]
    }
   ],
   "source": [
    "texts, labels = [], []\n",
    "\n",
    "# Cap every class at 30000 randomly sampled texts (the capped list is written\n",
    "# back into `emotion`), then flatten everything into two parallel lists: the\n",
    "# texts and their integer class ids.\n",
    "for name, samples in emotion.items():\n",
    "    if len(samples) > 30000:\n",
    "        samples = random.sample(samples, 30000)\n",
    "        emotion[name] = samples\n",
    "    print(name, len(samples))\n",
    "    texts.extend(samples)\n",
    "    labels.extend(len(samples) * [emotion_label.index(name)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 140674/140674 [00:12<00:00, 11404.78it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Clean every text in place; the list keeps its length and order, so it stays\n",
    "# aligned with `labels`.\n",
    "for i, raw_text in enumerate(tqdm(texts)):\n",
    "    texts[i] = cleaning(raw_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 140674/140674 [00:00<00:00, 1332427.45it/s]\n"
     ]
    }
   ],
   "source": [
    "actual_t, actual_l = [], []\n",
    "\n",
    "# Keep only the samples whose cleaned text is longer than two characters.\n",
    "for cleaned, label in tqdm(zip(texts, labels), total = len(texts)):\n",
    "    if len(cleaned) <= 2:\n",
    "        continue\n",
    "    actual_t.append(cleaned)\n",
    "    actual_l.append(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 140674/140674 [00:19<00:00, 7343.76it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "input_ids, input_masks = [], []\n",
    "\n",
    "# Encode each text as [CLS] + word pieces + [SEP]. The mask is all ones here;\n",
    "# zeros only appear later, when a batch is padded.\n",
    "for text in tqdm(actual_t):\n",
    "    pieces = [\"[CLS]\", *tokenizer.tokenize(text), \"[SEP]\"]\n",
    "    ids = tokenizer.convert_tokens_to_ids(pieces)\n",
    "    input_ids.append(ids)\n",
    "    input_masks.append([1] * len(ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "epoch = 3\n",
    "batch_size = 60\n",
    "warmup_proportion = 0.1\n",
    "# Fraction held out for evaluation; keep in sync with `test_size` in the\n",
    "# train_test_split cell.\n",
    "test_fraction = 0.2\n",
    "# The optimizer steps once per *training* minibatch, so the learning-rate\n",
    "# schedule is sized from the training split of the filtered samples (not the\n",
    "# whole corpus) and counts the final partial batch of each epoch.\n",
    "num_train_samples = len(actual_t) - math.ceil(len(actual_t) * test_fraction)\n",
    "num_train_steps = math.ceil(num_train_samples / batch_size) * epoch\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_initializer(initializer_range=0.02):\n",
    "    \"\"\"Build a truncated-normal initializer with standard deviation `initializer_range`.\"\"\"\n",
    "    initializer = tf.truncated_normal_initializer(stddev=initializer_range)\n",
    "    return initializer\n",
    "\n",
    "class Model:\n",
    "    \"\"\"\n",
    "    ALBERT encoder with a two-layer dense head for sequence classification.\n",
    "\n",
    "    Building an instance adds all ops to the current default graph. It reads the\n",
    "    module-level globals `albert_config`, `num_train_steps` and\n",
    "    `num_warmup_steps`, so those cells must have been run first.\n",
    "\n",
    "    Attributes set by __init__: X, MASK, Y (placeholders), logits_seq, logits,\n",
    "    cost, optimizer, accuracy.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "        training = True,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        dimension_output: number of classes (width of the final dense layer).\n",
    "        learning_rate: initial learning rate handed to create_optimizer.\n",
    "        training: forwarded to AlbertModel as is_training. The optimizer op is\n",
    "            created regardless of this flag.\n",
    "        \"\"\"\n",
    "        # Token ids and attention mask, both int32 [batch, sequence]; labels are int32 [batch].\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.MASK = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        model = modeling.AlbertModel(\n",
    "            config=albert_config,\n",
    "            is_training=training,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.MASK,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        # Head: tanh dense layer of hidden_size, then a linear projection to the\n",
    "        # classes, applied at every sequence position.\n",
    "        output_layer = model.get_sequence_output()\n",
    "        output_layer = tf.layers.dense(\n",
    "            output_layer,\n",
    "            albert_config.hidden_size,\n",
    "            activation=tf.tanh,\n",
    "            kernel_initializer=create_initializer())\n",
    "        self.logits_seq = tf.layers.dense(output_layer, dimension_output,\n",
    "                                         kernel_initializer=create_initializer())\n",
    "        # Named identities give stable node names ('logits_seq', 'logits') that\n",
    "        # the freeze step later selects by name.\n",
    "        self.logits_seq = tf.identity(self.logits_seq, name = 'logits_seq')\n",
    "        # Classification uses position 0 only, where this notebook's encoding puts [CLS].\n",
    "        self.logits = self.logits_seq[:, 0]\n",
    "        self.logits = tf.identity(self.logits, name = 'logits')\n",
    "        \n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        \n",
    "        # NOTE(review): the trailing False is presumably the use_tpu flag of\n",
    "        # albert.optimization.create_optimizer - confirm against that module.\n",
    "        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
    "                                                       num_train_steps, num_warmup_steps, False)\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint prefix of the pre-trained ALBERT-tiny weights; restored into the\n",
    "# 'bert' scope right after the model is built.\n",
    "BERT_INIT_CHKPNT = 'albert-tiny-2020-04-17/model.ckpt-1000000'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:194: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:507: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:588: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:1025: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/modeling.py:253: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:36: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:41: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "INFO:tensorflow:++++++ warmup starts at step 0, for 703 steps ++++++\n",
      "INFO:tensorflow:using adamw\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/albert/optimization.py:101: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/math_grad.py:1375: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "INFO:tensorflow:Restoring parameters from albert-tiny-2020-04-17/model.ckpt-1000000\n"
     ]
    }
   ],
   "source": [
    "# One output unit per entry of emotion_label.\n",
    "dimension_output = 6\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Build the training graph (Model defaults to training=True) in a fresh default\n",
    "# graph owned by a new interactive session.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "# Initialise every variable first, then overwrite only the trainable variables\n",
    "# under the 'bert' scope with the pre-trained checkpoint; the dense head created\n",
    "# in Model keeps its fresh initialisation.\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 20% of the samples for evaluation. The fixed random_state makes the\n",
    "# shuffle of this split repeatable across kernel restarts, so the held-out rows\n",
    "# do not change between runs on the same input lists.\n",
    "train_input_ids, test_input_ids, train_Y, test_Y, train_mask, test_mask = train_test_split(\n",
    "    input_ids, actual_l, input_masks, test_size = 0.2, random_state = 42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias; used below with padding='post' to pad each minibatch.\n",
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 1876/1876 [01:30<00:00, 20.76it/s, accuracy=0.974, cost=0.138] \n",
      "test minibatch loop: 100%|██████████| 469/469 [00:08<00:00, 53.57it/s, accuracy=1, cost=0.00849]   \n",
      "train minibatch loop:   0%|          | 3/1876 [00:00<01:28, 21.11it/s, accuracy=1, cost=0.00833]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, training loss: 0.284124, training acc: 0.924880, valid loss: 0.051377, valid acc: 0.987598\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 1876/1876 [01:29<00:00, 20.98it/s, accuracy=0.974, cost=0.166] \n",
      "test minibatch loop: 100%|██████████| 469/469 [00:08<00:00, 54.18it/s, accuracy=1, cost=0.00195]   \n",
      "train minibatch loop:   0%|          | 2/1876 [00:00<01:35, 19.72it/s, accuracy=1, cost=0.00162]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, training loss: 0.035439, training acc: 0.991529, valid loss: 0.034821, valid acc: 0.991684\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 1876/1876 [01:29<00:00, 20.98it/s, accuracy=0.974, cost=0.037] \n",
      "test minibatch loop: 100%|██████████| 469/469 [00:08<00:00, 54.36it/s, accuracy=1, cost=0.00078]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, training loss: 0.019873, training acc: 0.995367, valid loss: 0.029992, valid acc: 0.993461\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "def _run_minibatches(ids, masks, labels, desc, train = False):\n",
    "    \"\"\"\n",
    "    Run the global `model` once over a dataset in minibatches of `batch_size`.\n",
    "\n",
    "    ids, masks, labels: parallel lists of token ids, attention masks and class ids.\n",
    "    desc: label shown on the progress bar.\n",
    "    train: when true the optimizer op is run as well, so the weights are updated.\n",
    "\n",
    "    Returns (mean loss, mean accuracy), each weighted by batch size so that a\n",
    "    shorter final batch does not count as much as a full one.\n",
    "    \"\"\"\n",
    "    fetches = [model.accuracy, model.cost]\n",
    "    if train:\n",
    "        fetches.append(model.optimizer)\n",
    "    total_loss, total_acc, seen = 0.0, 0.0, 0\n",
    "    pbar = tqdm(range(0, len(ids), batch_size), desc = desc)\n",
    "    for i in pbar:\n",
    "        # Slicing past the end is safe, so the last batch is simply shorter.\n",
    "        batch_x = pad_sequences(ids[i: i + batch_size], padding='post')\n",
    "        batch_mask = pad_sequences(masks[i: i + batch_size], padding='post')\n",
    "        batch_y = labels[i: i + batch_size]\n",
    "        acc, cost = sess.run(\n",
    "            fetches,\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.MASK: batch_mask\n",
    "            },\n",
    "        )[:2]\n",
    "        total_loss += cost * len(batch_y)\n",
    "        total_acc += acc * len(batch_y)\n",
    "        seen += len(batch_y)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    # max(..., 1) avoids a ZeroDivisionError on an empty dataset.\n",
    "    return total_loss / max(seen, 1), total_acc / max(seen, 1)\n",
    "\n",
    "# NOTE(review): the model in this session is built with training=True, so\n",
    "# is_training is also on during the 'test' pass (presumably keeping dropout\n",
    "# active); the final evaluation further down rebuilds it with training=False.\n",
    "for EPOCH in range(epoch):\n",
    "    train_loss, train_acc = _run_minibatches(\n",
    "        train_input_ids, train_mask, train_Y, 'train minibatch loop', train = True\n",
    "    )\n",
    "    test_loss, test_acc = _run_minibatches(\n",
    "        test_input_ids, test_mask, test_Y, 'test minibatch loop'\n",
    "    )\n",
    "    \n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'albert-tiny-emotion/model.ckpt'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Make sure the target directory exists: tf.train.Saver.save can fail when the\n",
    "# parent directory of the checkpoint prefix is missing (fresh checkout).\n",
    "os.makedirs('albert-tiny-emotion', exist_ok = True)\n",
    "# Only the trainable variables are written to the fine-tuned checkpoint.\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'albert-tiny-emotion/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:++++++ warmup starts at step 0, for 703 steps ++++++\n",
      "INFO:tensorflow:using adamw\n",
      "INFO:tensorflow:Restoring parameters from albert-tiny-emotion/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 6\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Release the previous session before opening a new one: creating a second\n",
    "# InteractiveSession while one is active triggers TensorFlow's 'An interactive\n",
    "# session is already active' warning and keeps the old session's resources\n",
    "# allocated. The guard lets this cell also run on a fresh kernel.\n",
    "if 'sess' in globals():\n",
    "    sess.close()\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "# Rebuild the graph with training = False (forwarded to AlbertModel as is_training).\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate,\n",
    "    training = False\n",
    ")\n",
    "\n",
    "# Initialise everything, then overwrite the trainable variables with the\n",
    "# fine-tuned checkpoint.\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.restore(sess, 'albert-tiny-emotion/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'bert/embeddings/word_embeddings',\n",
       " 'bert/embeddings/token_type_embeddings',\n",
       " 'bert/embeddings/position_embeddings',\n",
       " 'bert/embeddings/LayerNorm/gamma',\n",
       " 'bert/encoder/embedding_hidden_mapping_in/kernel',\n",
       " 'bert/encoder/embedding_hidden_mapping_in/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/key/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/value/bias',\n",
       " 'bert/encoder/transformer/group_0/layer_0/inner_group_0/attention_1/self/attention_probs',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/attention_1/output/dense/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/LayerNorm/gamma',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/kernel',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/bias',\n",
       " 'bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma',\n",
       " 'bert/encoder/transformer/group_0_1/layer_1/inner_group_0/attention_1/self/attention_probs',\n",
       " 'bert/encoder/transformer/group_0_2/layer_2/inner_group_0/attention_1/self/attention_probs',\n",
       " 'bert/encoder/transformer/group_0_3/layer_3/inner_group_0/attention_1/self/attention_probs',\n",
       " 'bert/pooler/dense/kernel',\n",
       " 'bert/pooler/dense/bias',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'dense_1/kernel',\n",
       " 'dense_1/bias',\n",
       " 'logits_seq',\n",
       " 'logits',\n",
       " 'gradients/bert/encoder/transformer/group_0_3/layer_3/inner_group_0/attention_1/self/attention_probs_grad/mul',\n",
       " 'gradients/bert/encoder/transformer/group_0_3/layer_3/inner_group_0/attention_1/self/attention_probs_grad/Sum/reduction_indices',\n",
       " 'gradients/bert/encoder/transformer/group_0_3/layer_3/inner_group_0/attention_1/self/attention_probs_grad/Sum',\n",
       " 'gradients/bert/encoder/transformer/group_0_3/layer_3/inner_group_0/attention_1/self/attention_probs_grad/sub',\n",
       " 'gradients/bert/encoder/transformer/group_0_3/layer_3/inner_group_0/attention_1/self/attention_probs_grad/mul_1',\n",
       " 'gradients/bert/encoder/transformer/group_0_2/layer_2/inner_group_0/attention_1/self/attention_probs_grad/mul',\n",
       " 'gradients/bert/encoder/transformer/group_0_2/layer_2/inner_group_0/attention_1/self/attention_probs_grad/Sum/reduction_indices',\n",
       " 'gradients/bert/encoder/transformer/group_0_2/layer_2/inner_group_0/attention_1/self/attention_probs_grad/Sum',\n",
       " 'gradients/bert/encoder/transformer/group_0_2/layer_2/inner_group_0/attention_1/self/attention_probs_grad/sub',\n",
       " 'gradients/bert/encoder/transformer/group_0_2/layer_2/inner_group_0/attention_1/self/attention_probs_grad/mul_1',\n",
       " 'gradients/bert/encoder/transformer/group_0_1/layer_1/inner_group_0/attention_1/self/attention_probs_grad/mul',\n",
       " 'gradients/bert/encoder/transformer/group_0_1/layer_1/inner_group_0/attention_1/self/attention_probs_grad/Sum/reduction_indices',\n",
       " 'gradients/bert/encoder/transformer/group_0_1/layer_1/inner_group_0/attention_1/self/attention_probs_grad/Sum',\n",
       " 'gradients/bert/encoder/transformer/group_0_1/layer_1/inner_group_0/attention_1/self/attention_probs_grad/sub',\n",
       " 'gradients/bert/encoder/transformer/group_0_1/layer_1/inner_group_0/attention_1/self/attention_probs_grad/mul_1',\n",
       " 'gradients/bert/encoder/transformer/group_0/layer_0/inner_group_0/attention_1/self/attention_probs_grad/mul',\n",
       " 'gradients/bert/encoder/transformer/group_0/layer_0/inner_group_0/attention_1/self/attention_probs_grad/Sum/reduction_indices',\n",
       " 'gradients/bert/encoder/transformer/group_0/layer_0/inner_group_0/attention_1/self/attention_probs_grad/Sum',\n",
       " 'gradients/bert/encoder/transformer/group_0/layer_0/inner_group_0/attention_1/self/attention_probs_grad/sub',\n",
       " 'gradients/bert/encoder/transformer/group_0/layer_0/inner_group_0/attention_1/self/attention_probs_grad/mul_1']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick the graph nodes to keep as outputs when freezing: variable ops, plus\n",
    "# nodes whose name contains one of `name_markers`, minus anything whose name\n",
    "# contains one of `excluded_markers`.\n",
    "name_markers = ('Placeholder', 'logits', 'alphas', 'attention_probs', 'self/Softmax')\n",
    "excluded_markers = ('adam', 'beta', 'global_step')\n",
    "\n",
    "selected = []\n",
    "for node in tf.get_default_graph().as_graph_def().node:\n",
    "    wanted = 'Variable' in node.op or any(m in node.name for m in name_markers)\n",
    "    if wanted and not any(m in node.name for m in excluded_markers):\n",
    "        selected.append(node.name)\n",
    "\n",
    "# freeze_graph expects the names as one comma-separated string.\n",
    "strings = ','.join(selected)\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"\n",
    "    Freeze the latest checkpoint in `model_dir` into a single GraphDef file.\n",
    "\n",
    "    The checkpoint's meta graph is imported into a fresh graph, its variables are\n",
    "    converted to constants and the result is written to 'frozen_model.pb' in the\n",
    "    directory of the checkpoint.\n",
    "\n",
    "    model_dir: directory containing the checkpoint state written by tf.train.Saver.\n",
    "    output_node_names: comma-separated node names passed to\n",
    "        convert_variables_to_constants as the outputs to keep.\n",
    "\n",
    "    Raises AssertionError when `model_dir` does not exist and ValueError when no\n",
    "    checkpoint is recorded in it.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when the directory holds no checkpoint\n",
    "    # state; fail with a clear message instead of an AttributeError.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise ValueError('No checkpoint found in directory: %s' % model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    # os.path also copes with a checkpoint path that has no directory component,\n",
    "    # where splitting on '/' would have produced '/frozen_model.pb'.\n",
    "    output_graph = os.path.join(os.path.dirname(input_checkpoint), 'frozen_model.pb')\n",
    "    clear_devices = True\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = clear_devices\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from albert-tiny-emotion/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-23-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 29 variables.\n",
      "INFO:tensorflow:Converted 29 variables to const ops.\n",
      "2826 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "# Freeze the fine-tuned checkpoint, keeping the nodes selected in `strings`.\n",
    "freeze_graph('albert-tiny-emotion', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"Parse the serialized GraphDef in `frozen_graph_filename` and return a new tf.Graph with it imported.\"\"\"\n",
    "    graph_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:\n",
    "        graph_def.ParseFromString(pb_file.read())\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "# Upload the frozen graph to S3. Note the naming: upload_file is called as\n",
    "# (local filename, bucket, object key), so `Key` below is actually the local\n",
    "# path of the frozen graph and `outPutname` is the destination object key.\n",
    "bucketName = 'huseinhouse-storage'\n",
    "Key = 'albert-tiny-emotion/frozen_model.pb'\n",
    "outPutname = \"v34/emotion/albert-tiny-emotion.pb\"\n",
    "\n",
    "s3 = boto3.client('s3')\n",
    "s3.upload_file(Key,bucketName,outPutname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 469/469 [00:07<00:00, 60.19it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "# Predict a class id for every held-out sample, batch by batch, and collect the\n",
    "# matching true labels in the same order.\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_input_ids), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for start in pbar:\n",
    "    stop = min(start + batch_size, len(test_input_ids))\n",
    "    batch_x = pad_sequences(test_input_ids[start: stop], padding='post')\n",
    "    batch_mask = pad_sequences(test_mask[start: stop], padding='post')\n",
    "    batch_logits = sess.run(\n",
    "        model.logits,\n",
    "        feed_dict = {model.X: batch_x, model.MASK: batch_mask},\n",
    "    )\n",
    "    predict_Y.extend(np.argmax(batch_logits, 1).tolist())\n",
    "    real_Y.extend(test_Y[start: stop])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       anger    0.99396   0.98603   0.98998      6012\n",
      "        fear    0.99390   0.99512   0.99451      4096\n",
      "       happy    0.99652   0.99652   0.99652      6030\n",
      "        love    0.99114   0.99187   0.99150      4059\n",
      "     sadness    0.99121   0.99699   0.99409      5316\n",
      "    surprise    0.99278   0.99619   0.99448      2622\n",
      "\n",
      "    accuracy                        0.99346     28135\n",
      "   macro avg    0.99325   0.99378   0.99351     28135\n",
      "weighted avg    0.99346   0.99346   0.99346     28135\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "# Reuse `emotion_label` - the list the integer labels were derived from - so\n",
    "# the class names in the report cannot drift from the label ids.\n",
    "print(\n",
    "    metrics.classification_report(\n",
    "        real_Y, predict_Y, target_names = emotion_label,\n",
    "        digits = 5\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
